#!/usr/bin/env python
import numpy as np
import tensorflow as tf
import dataIO as d
import vae_model as vm
import time

# ---- Training hyper-parameters (read by trainGAN) -------------------------
n_epochs = 20000        # training iterations; each draws one random batch
batch_size = 5          # volumes per batch; also passed to vm.decoder
beta = 0.5              # beta1 of both Adam optimizers
validation_step = 100   # run validation every this many iterations
side_len = 64           # edge length of the cubic input volumes

# ---- Output locations -----------------------------------------------------
# NOTE(review): the TensorBoard log location is the same path as the
# checkpoint. tf.train.Saver uses it as a checkpoint file *prefix*, while
# tf.summary.FileWriter uses it as a *directory*. Consider giving the logs
# their own directory.
MODEL_DIRECTORY = "store/model.ckpt-z-15"
LOGS_DIRECTORY = MODEL_DIRECTORY


def trainGAN():
    """Train the VAE from ``vae_model`` on voxel volumes loaded by ``dataIO``.

    Despite its name, no adversarial network is built here: the graph is an
    encoder/decoder pair. Each iteration does the following:

    * samples ``batch_size`` random training volumes;
    * evaluates losses and metrics on that batch;
    * takes one encoder step on the KL term, skipped once the term is at or
      below 1e-6;
    * takes one decoder step on the reconstruction loss;
    * writes TensorBoard summaries to ``LOGS_DIRECTORY``.

    Every ``validation_step`` iterations, PSNR and SSIM are measured on a
    random validation batch. The model is checkpointed to ``MODEL_DIRECTORY``
    whenever validation PSNR reaches a new maximum. It is also saved
    unconditionally every ``2 * validation_step`` iterations.
    Per-validation PSNR/SSIM values are written to ``pnsr.txt`` and
    ``ssim.txt`` when training finishes.

    Raises:
        ValueError: if ``dataIO.getAll`` returns no volumes beyond the
            training split, leaving nothing to validate on.
    """
    # Start from a fresh graph so repeated calls do not accumulate ops.
    tf.reset_default_graph()

    # Input: a batch of single-channel cubic volumes, plus a flattened view
    # used by the per-sample reconstruction loss.
    x_vector = tf.placeholder(shape=[None, side_len, side_len, side_len, 1], dtype=tf.float32)
    x_flat = tf.reshape(x_vector, shape=[-1, side_len * side_len * side_len])

    # The encoder returns the latent sample together with its mean and stddev.
    z_latent, mean, stddev = vm.encoder(x_vector, phase_train=True, reuse=False, training=False)

    # Reconstruct a volume from the latent. The decoder is built with the
    # fixed ``batch_size``. Presumably every batch fed below must have that
    # size; confirm in vm.decoder.
    out_image = vm.decoder(z_latent, batch_size, phase_train=True, reuse=False, training=False)
    x_out_flat = tf.reshape(out_image, shape=[-1, side_len * side_len * side_len])

    # KL divergence of N(mean, stddev^2) from N(0, 1), averaged over all
    # latent entries. The 1e-8 guards against log(0).
    with tf.name_scope("ELOSS"):
        vae_loss = tf.reduce_mean(0.5 * (tf.square(mean) + tf.square(stddev) - 2.0 * tf.log(stddev + 1e-8) - 1.0))
    tf.summary.scalar("ELOSS", vae_loss)

    # Reconstruction loss: per-sample sum of squared errors, averaged over the batch.
    with tf.name_scope("DLOSS"):
        rec_loss_per_sample = tf.reduce_sum(tf.squared_difference(x_out_flat, x_flat), 1)
        rec_loss = tf.reduce_mean(rec_loss_per_sample)
    tf.summary.scalar("DLOSS", rec_loss)

    # Total loss. It is only logged; no optimizer below minimises it directly.
    with tf.name_scope("TOTAL_LOSS"):
        loss = vae_loss + rec_loss
    tf.summary.scalar("TOTAL_LOSS", loss)

    # PSNR with an implicit peak value of 1: -10 * log10(MSE).
    with tf.name_scope("PSNR"):
        mse = tf.reduce_mean(tf.square(x_vector - out_image))
        accuracy = -10.0 * tf.log(mse) / tf.log(10.)
    tf.summary.scalar("PSNR", accuracy)

    # Single-window SSIM: the means, variances and covariance are each taken
    # over the whole batch tensor at once, not the usual sliding-window SSIM.
    with tf.name_scope("Structure_similarity_index"):
        mean_y_ = tf.reduce_mean(x_vector)
        mean_y = tf.reduce_mean(out_image)
        var_y_ = tf.reduce_mean((x_vector - mean_y_) ** 2)
        var_y = tf.reduce_mean((out_image - mean_y) ** 2)
        covy_y = tf.reduce_mean((out_image - mean_y) * (x_vector - mean_y_))
        SSIM_up = (2 * mean_y_ * mean_y + 1e-6) * (2 * covy_y + 1e-6)
        SSIM_down = (mean_y_ ** 2 + mean_y ** 2 + 1e-6) * (var_y_ + var_y + 1e-6)
        SSIM = SSIM_up / SSIM_down
    tf.summary.scalar("structure_similarity_index", SSIM)

    # Optimizer definition.
    with tf.name_scope("ADAM"):
        # Split the trainable variables into decoder/encoder sets by name substring.
        para_decoder = [var for var in tf.trainable_variables() if any(x in var.name for x in ['wd', 'bd', 'decoder'])]
        para_encoder = [var for var in tf.trainable_variables() if any(x in var.name for x in ['we', 'be', 'encoder'])]

        # Each optimizer owns a step counter. The counter drives a staircase
        # exponential learning-rate decay applied every 50 of its own steps.
        batch_e = tf.Variable(0)
        learning_rate_e = tf.train.exponential_decay(
            learning_rate=0.001, global_step=batch_e, decay_steps=50, decay_rate=0.95, staircase=True)

        batch_d = tf.Variable(0)
        learning_rate_d = tf.train.exponential_decay(
            learning_rate=0.001, global_step=batch_d, decay_steps=50, decay_rate=0.98, staircase=True)

        # NOTE(review): the encoder minimises only the KL term, so the
        # reconstruction gradient never reaches the encoder weights. A
        # standard VAE trains the encoder on ``loss`` (KL + reconstruction).
        # Left unchanged; confirm this is intended.
        optimizer_op_encoder = tf.train.AdamOptimizer(learning_rate=learning_rate_e, beta1=beta).minimize(
            vae_loss, var_list=para_encoder, global_step=batch_e)
        # The decoder is updated on the reconstruction loss only.
        optimizer_op_decoder = tf.train.AdamOptimizer(learning_rate=learning_rate_d, beta1=beta).minimize(
            rec_loss, var_list=para_decoder, global_step=batch_d)

    tf.summary.scalar('learning_rate_e', learning_rate_e)
    tf.summary.scalar('learning_rate_d', learning_rate_d)

    # Merge all summaries into a single op.
    merged_summary_op = tf.summary.merge_all()

    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()

    # Split the data: the first ``n_train`` volumes train, the rest validate.
    # reshape(-1, ...) accepts any dataset size instead of exactly 5000/1000.
    n_train = 5000
    volume_shape = (-1, side_len, side_len, side_len, 1)
    volumes = d.getAll(side_len)
    volume_train = volumes[:n_train].reshape(volume_shape)
    volume_validation = volumes[n_train:].reshape(volume_shape)
    if len(volume_validation) == 0:
        raise ValueError("no validation volumes: dataset has %d volumes, %d are used for training"
                         % (len(volumes), n_train))

    # One slot per validation run: epochs 0, validation_step, 2*validation_step, ...
    n_validations = (n_epochs + validation_step - 1) // validation_step
    pnsr_vali_store = np.zeros(n_validations)
    ssim_vali_store = np.zeros(n_validations)
    mark = 0

    # Best validation PSNR seen so far. It gates "improved" checkpoints.
    max_acc = -np.inf

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Op to write logs to TensorBoard. It is closed in ``finally`` so
        # pending events are flushed even if training fails.
        summary_writer = tf.summary.FileWriter(LOGS_DIRECTORY, graph=tf.get_default_graph())
        try:
            for epoch in range(n_epochs):
                idx = np.random.randint(len(volume_train), size=batch_size)
                x = volume_train[idx]

                # Metrics are evaluated before this iteration's updates.
                rec_loss_R, train_accuracy, SSIM_R, vae_loss_R = sess.run(
                    [rec_loss, accuracy, SSIM, vae_loss], feed_dict={x_vector: x})

                # Skip the encoder step once the KL term is negligible.
                if vae_loss_R > 1e-6:
                    sess.run([optimizer_op_encoder], feed_dict={x_vector: x})

                _, summary = sess.run([optimizer_op_decoder, merged_summary_op], feed_dict={x_vector: x})

                print("Epoch:", '%03d,' % (epoch + 1),
                      "ec_loss %.3f, dc_loss %.2f, PNSR %.3f, SSIM %.3f" % (vae_loss_R, rec_loss_R, train_accuracy, SSIM_R))

                # Write logs at every iteration.
                summary_writer.add_summary(summary, epoch)

                if epoch % validation_step == 0:
                    idx = np.random.randint(len(volume_validation), size=batch_size)
                    x_validation = volume_validation[idx]
                    vali_accuracy, vali_SSIM = sess.run([accuracy, SSIM], feed_dict={x_vector: x_validation})
                    pnsr_vali_store[mark] = vali_accuracy
                    ssim_vali_store[mark] = vali_SSIM
                    mark += 1
                    print("Model validation: %s" % mark, "PNSR %.3f, SSIM %.3f" % (vali_accuracy, vali_SSIM))

                    # Save when *validation* PSNR improves, and periodically
                    # regardless. A periodic save must not lower the best-so-far.
                    improved = vali_accuracy > max_acc
                    if improved:
                        max_acc = vali_accuracy
                    if improved or epoch % (2 * validation_step) == 0:
                        save_path = saver.save(sess, MODEL_DIRECTORY)
                        print("Model updated and saved in file: %s" % save_path)
        finally:
            summary_writer.close()

    np.savetxt('pnsr.txt', pnsr_vali_store)
    np.savetxt('ssim.txt', ssim_vali_store)

    print("Optimization Finished!")
    print('Run `tensorboard --logdir=%s` to see the results.' % LOGS_DIRECTORY)


if __name__ == '__main__':
    # Script entry point: run training and report how long it took.
    # perf_counter is monotonic and high-resolution, so the measured duration
    # is not skewed by system-clock adjustments during a long run.
    start_time = time.perf_counter()

    trainGAN()

    elapsed = time.perf_counter() - start_time
    print("This sampling run took %5.4f seconds." % elapsed)

    

